# SIU KING WAI SM4701 Deepstory
import re
import copy
import spacy
import librosa
import numpy as np

from unidecode import unidecode
from modules.dctts import hp
from pydub import AudioSegment, effects


def quote_boundaries(doc):
    """spaCy pipeline component: begin a new sentence right after an opening curly quote.

    Every token whose text is ``“`` (the final token of ``doc`` excluded, since
    nothing follows it) causes the token after it to be flagged with
    ``is_sent_start = True``. Closing quotes are not treated as boundaries.
    The same ``doc`` object is returned so the pipeline can continue with it.
    """
    openers = [tok for tok in doc[:-1] if tok.text == "“"]
    for tok in openers:
        doc[tok.i + 1].is_sent_start = True
    return doc


# Two spaCy English pipelines used for sentence splitting:
#   nlp          - its sentencizer also breaks sentences at commas
#   nlp_no_comma - its sentencizer uses the default punctuation only
# Both include quote_boundaries, which starts a sentence after an opening “.
# NOTE(review): this is the spaCy 2.x pipeline API (create_pipe, add_pipe with
# a plain callable); spaCy 3 requires registered component names -- confirm
# the pinned spaCy version.
nlp = spacy.load('en_core_web_sm')
nlp.add_pipe(quote_boundaries, before="parser")
# The copy is taken before any sentencizer is added, so the two pipelines
# differ only in the sentencizer configured below.
nlp_no_comma = copy.deepcopy(nlp)
sentencizer = nlp.create_pipe("sentencizer")
# Make ',' an extra sentence-ending character for `nlp` only.
sentencizer.punct_chars.add(',')
sentencizer_no_comma = nlp_no_comma.create_pipe("sentencizer")
# first=True places each sentencizer at position 0, ahead of every other
# component (including quote_boundaries and the parser).
nlp.add_pipe(sentencizer, first=True)
nlp_no_comma.add_pipe(sentencizer_no_comma, first=True)


def normalize_text(text):
    """Normalize text into the character set of ``hp.vocab``.

    Punctuation that indicates a pause is replaced with commas. Steps, in order:
      1. A trailing ``...`` or ``…`` becomes ``.``.
      2. Pause punctuation -- parentheses, ``:``, ``;``, an opening ``“``
         preceded by a space, hyphens with whitespace on either side, runs of
         two or more hyphens, ``...``, ``…`` and ``—`` -- becomes ``", "``.
      3. Commas separated only by non-word characters collapse into one, and
         the spacing around every comma is normalized to ``", "``.
      4. ``.,`` becomes ``.`` and the curly quotes ``‘’“”`` are deleted.
      5. Accented characters are transliterated with ``unidecode``, the text
         is lowercased, every character not in ``hp.vocab`` becomes a space,
         runs of spaces are collapsed and the ends are stripped.

    NOTE(review): step 4 also deletes in-word ``’`` apostrophes
    ("don’t" -> "dont"). ``fix_text`` converts those to ASCII ``'`` first --
    confirm callers run it when apostrophes matter.
    """
    replace_list = [
        [r'(\.\.\.)$|…$', '.'],
        [r'\(|\)|:|;| “|(\s*-+\s+)|(\s+-+\s*)|\s*-{2,}\s*|(\.\.\.)|…|—', ', '],
        [r'\s*,[^\w]*,\s*', ', '],  # capture multiple commas
        [r'\s*,\s*', ', '],  # format commas
        [r'\.,', '.'],
        [r'[‘’“”]', '']  # strip quote
    ]
    for regex, replacement in replace_list:
        text = re.sub(regex, replacement, text)
    text = unidecode(text)  # Get rid of the accented characters
    text = text.lower()
    # Escape the vocabulary so characters such as '-', ']', '^' or '\' are
    # matched literally instead of forming a range or breaking the class.
    text = re.sub(f"[^{re.escape(hp.vocab)}]", " ", text)
    text = re.sub(r' +', ' ', text).strip()
    return text


def fix_text(text):
    """Clean up text pasted from a book.

    A curly ``’`` sitting between two word characters is turned into an ASCII
    apostrophe ``'``. Runs of spaces are then collapsed to a single space.
    Other whitespace (e.g. newlines) is left untouched.
    """
    apostrophes_fixed = re.sub(r'(\w)’(\w)', r"\1'\2", text)
    return re.sub(r' +', ' ', apostrophes_fixed)


def trim_text(generated_text, max_sentences=0, script=False):
    """Trim the unfinished trailing sentence from GPT-2 generated text.

    The text is cut into chunks that each end at a sentence terminator
    (``.``, ``!``, ``?``, ``…`` or ``—``), optionally followed by closing
    quotes (``’`` / ``”``). Anything after the last terminator is discarded.

    Args:
        generated_text: Raw generated text.
        max_sentences: If truthy, keep only the first ``max_sentences`` chunks.
            ``0`` (the default) keeps them all.
        script: Script mode. Newlines are removed before chunking. Every chunk
            whose first character is uppercase -- i.e. one that began directly
            after a terminator with no space, such as a new line of dialogue --
            is then prefixed with a newline.

    Returns:
        The kept chunks concatenated with no separator added.
    """
    # Drop U+FFFD (the Unicode replacement character). It presumably comes
    # from bytes the tokenizer could not decode.
    generated_text = generated_text.replace('\ufffd', '')
    if script:
        generated_text = generated_text.replace('\n', '')
    text_list = re.findall(r'.*?[.!\?…—][’”]*', generated_text, re.DOTALL)
    if script:
        # Each match ends with a terminator character, so text[0] always exists.
        text_list = ['\n' + text if text[0].isupper() else text for text in text_list]
    # A falsy limit (0 / None) becomes a None slice bound, i.e. keep every chunk.
    return ''.join(text_list[:max_sentences or None])


def separate(text, n_gram, comma, max_len=30):
    """Group the sentences of ``text`` into chunks of ``n_gram`` sentences.

    Args:
        text: Input text to split.
        n_gram: Number of consecutive sentences joined into one chunk.
        comma: If true, parse with ``nlp`` (whose sentencizer also breaks at
            commas); otherwise parse with ``nlp_no_comma``.
        max_len: Currently unused.

    Returns:
        A list of spaCy ``Doc`` objects, one per chunk. Each is produced by
        re-parsing the chunk's sentence texts joined with single spaces. A
        final chunk holding fewer than ``n_gram`` sentences is still included.
    """
    pipeline = nlp if comma else nlp_no_comma
    chunks = []
    pending = []
    for sentence in pipeline(text).sents:
        if not sentence.text:
            continue
        pending.append(sentence.text)
        if len(pending) == n_gram:
            chunks.append(pipeline(' '.join(pending)))
            pending = []

    # Flush the sentences left over after the last complete group.
    if pending:
        chunks.append(pipeline(' '.join(pending)))

    return chunks


def get_duration(second):
    """Return how many samples span ``second`` seconds at sample rate ``hp.sr``.

    The result is truncated toward zero by ``int``.
    """
    n_samples = hp.sr * second
    return int(n_samples)


def normalize_audio(wav, target_dbfs=-30):
    """Apply gain so the audio's loudness (pydub ``AudioSegment.dBFS``) hits ``target_dbfs``.

    Args:
        wav: NumPy array whose raw bytes are interpreted as 16-bit mono PCM at
            ``hp.sr``. NOTE(review): this assumes dtype int16; other dtypes
            are reinterpreted, not converted -- confirm callers convert first.
        target_dbfs: Desired loudness in dBFS. Defaults to -30, the previously
            hard-coded target.

    Returns:
        NumPy array of the gain-adjusted samples. Silent input (dBFS of
        ``-inf``) is returned without any gain, because no finite gain could
        reach the target and the subtraction would request infinite gain.
    """
    audioseg = AudioSegment(wav.tobytes(), sample_width=2, frame_rate=hp.sr, channels=1)
    loudness = audioseg.dBFS
    if np.isinf(loudness):
        return np.array(audioseg.get_array_of_samples())
    normalized = audioseg.apply_gain(target_dbfs - loudness)
    return np.array(normalized.get_array_of_samples())


# from my audio processing project
def split_audio_to_list(source, preemph=True, preemphasis=0.8, min_diff=1500, min_size=get_duration(1), db=80):
    """Split an audio signal into contiguous segments at detected silences.

    Non-silent intervals are found with ``librosa.effects.split``. Short
    intervals are dropped and intervals separated by short gaps are merged.
    The surviving boundaries are then stretched so the segments tile the
    whole signal with no gaps.

    Args:
        source: Sample array (assumed 1-D mono -- TODO confirm with callers).
        preemph: If true, filter the detection signal with
            ``y[n] = x[n] - preemphasis * x[n-1]`` first. Only a local copy is
            filtered; the caller's array is not modified.
        preemphasis: Pre-emphasis coefficient used when ``preemph`` is true.
        min_diff: Gap threshold in samples. An interval longer than
            ``min_size`` is merged into the interval before it when the gap
            between them is smaller than this.
        min_size: Length threshold in samples. Intervals after the first whose
            length is not greater than this are discarded before tiling. The
            default ``get_duration(1)`` (one second at ``hp.sr``) is evaluated
            once, when the function is defined.
        db: ``top_db`` threshold passed to ``librosa.effects.split``.

    Returns:
        List of ``[start, end]`` sample-index pairs. The first start is 0, the
        last end is ``len(source)``, and every start equals the previous end.

    Raises:
        IndexError: if ``source`` is empty and ``preemph`` is true
        (``source[0]``).
    """
    if preemph:
        source = np.append(source[0], source[1:] - preemphasis * source[:-1])
    # [[start, end], ...] non-silent intervals, in samples.
    split_list = librosa.effects.split(source, top_db=db).tolist()
    # Walk backwards so pop(i) never shifts indices still to be visited;
    # interval 0 is never examined, so it is always kept.
    i = len(split_list) - 1
    while i > 0:
        if split_list[i][-1] - split_list[i][0] > min_size:
            now = split_list[i][0]
            prev = split_list[i - 1][1]
            diff = now - prev
            if diff < min_diff:
                # Merge interval i into interval i-1. split_list[i - 1][0] is
                # read before pop(i) removes interval i.
                split_list[i - 1] = [split_list[i - 1][0], split_list.pop(i)[1]]
        else:
            # Too short: drop it. The tiling below gives its samples to an
            # adjacent segment.
            split_list.pop(i)
        i -= 1

    # make sure nothing is trimmed away: extend the first segment to 0, the
    # last to len(source), and start each segment where the previous one ends,
    # so every silent gap is attached to the start of the following segment.
    split_list[0][0] = 0
    split_list[-1][1] = len(source)
    for i in reversed(range(len(split_list))):
        if i != 0:
            split_list[i][0] = split_list[i - 1][1]

    return split_list
